import os
import click
import torch
import torch.distributed as dist
import yaml
import torchvision.utils

import utils.graph_lib
from models.model_utils import get_model, get_preconditioned_model
from utils.misc import dotdict
from utils.samplers import get_pc_sampler
from utils.guidance_schedules import get_guidance_schedule



@click.command()
@click.option('--model', type=click.Choice(['uvit']), default='uvit')
@click.option('--num_samples', type=int, default=100)
@click.option('--batch_size', type=int, default=50)
@click.option('--num_steps', type=int, default=100)
@click.option('--w', type=float, default=1.)
@click.option('--seed', type=int, default=42)
@click.option('--dir', type=str, required=True,
              help='Output directory for the generated sample grids')
@click.option('--net_config_path', type=str, default='configs/uvit.yaml')
@click.option('--load_checkpoint', type=str, required=True,
              help='Path of the checkpoint file to load')
def sampling(**opts):
    """Generate image grids from a checkpointed model on every distributed rank.

    Each rank seeds its own RNG, builds the model, loads the checkpoint and
    then runs ``num_samples // batch_size`` batches of ``batch_size`` samples.
    One PNG grid per batch is written to
    ``<dir>/rank_<rank>/Batch_<i>/samples_grid_<w>.png``.

    Args:
        **opts: The click options declared above (``model``, ``num_samples``,
            ``batch_size``, ``num_steps``, ``w``, ``seed``, ``dir``,
            ``net_config_path``, ``load_checkpoint``).

    Raises:
        click.UsageError: If ``batch_size`` is not positive or is not
            divisible by the world size.
    """
    opts = dotdict(opts)
    batch_size = opts.batch_size
    # Validate before any process group exists; a zero batch size would
    # otherwise surface later as a ZeroDivisionError.
    if batch_size <= 0:
        raise click.UsageError("Batch size must be a positive integer.")

    dist.init_process_group('nccl')
    try:
        world_size = dist.get_world_size()
        # Explicit raise instead of `assert` so the check survives `python -O`.
        if batch_size % world_size != 0:
            raise click.UsageError("Batch size must be divisible by world size.")
        rank = dist.get_rank()
        device = rank % torch.cuda.device_count()
        # Distinct seed per rank so ranks do not produce identical samples.
        seed = opts.seed * world_size + rank
        torch.manual_seed(seed)
        torch.cuda.set_device(device)
        print(f"Starting rank={rank}, seed={seed}, world_size={world_size}.")

        # Context manager closes the config file deterministically.
        with open(opts.net_config_path) as config_file:
            net_opts = dotdict(yaml.safe_load(config_file))
        if rank == 0:
            print(opts)
            print(net_opts)

        vocab_size = 256
        context_len = 784  # 28 * 28, matches the reshape to 28x28 images below
        graph = utils.graph_lib.Absorbing(vocab_size)

        # NOTE(review): the extra token (vocab_size + 1) is presumably the
        # absorbing/mask state of the graph -- confirm against get_model.
        model = get_model(opts.model, vocab_size + 1, context_len, net_opts)
        model = get_preconditioned_model(model, graph).to(device)

        load_checkpoint(opts, rank, model)

        # Wait until every rank has finished loading before sampling.
        dist.barrier()

        model.eval()

        if rank == 0:
            print(f"Model parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)//1e6} M")

        if rank == 0:
            # exist_ok avoids the check-then-create race of os.path.exists().
            os.makedirs(opts.dir, exist_ok=True)

        num_iters = opts.num_samples // batch_size
        if rank == 0 and opts.num_samples % batch_size != 0:
            print(f"Warning: num_samples={opts.num_samples} is not a multiple of "
                  f"batch_size={batch_size}; each rank will generate "
                  f"{num_iters * batch_size} samples.")

        for batch_idx in range(num_iters):
            path_samples = os.path.join(opts.dir, f'rank_{rank}/Batch_{batch_idx}')
            os.makedirs(path_samples, exist_ok=True)

            # NOTE: the random labels are immediately overwritten, so every
            # sample is conditioned on label 8. The randint call is kept so
            # the RNG stream (and therefore the samples produced for a given
            # seed) stays unchanged.
            labels = torch.randint(0, 10, (batch_size, 1), device=device)
            labels = torch.ones_like(labels) * 8
            guid_schedule = get_guidance_schedule('constant', scale=opts.w)
            encoded_im = get_pc_sampler(model, (batch_size, context_len), labels, opts.num_steps,
                                        guidance_schedule=guid_schedule,
                                        device=device,
                                        graph=graph,
                                        force_condition_class=None)

            encoded_im = encoded_im.reshape(-1, 1, 28, 28)

            # Assumes sampled values are pixel intensities in 0..255, so
            # dividing by 255 maps them to [0, 1] -- TODO confirm.
            grid = torchvision.utils.make_grid(encoded_im.cpu().float() / 255.0,
                                               nrow=int(batch_size**0.5),  # Make a square-ish grid
                                               padding=2,
                                               normalize=False)
            torchvision.utils.save_image(grid, os.path.join(path_samples, f'samples_grid_{opts.w}.png'))

        dist.barrier()
    finally:
        # Always tear the process group down, even when sampling fails.
        dist.destroy_process_group()


def load_checkpoint(opts, rank, model):
    """Load network weights from ``opts.load_checkpoint`` into ``model.net``.

    Args:
        opts: Options object exposing ``load_checkpoint``, the path of the
            checkpoint file.
        rank: Rank of the calling process; only used in log messages.
        model: Object whose ``net`` attribute receives the weights stored
            under the ``'model'`` key of the checkpoint.

    The state dict is loaded with ``strict=False``; any missing or unexpected
    keys are printed rather than silently ignored.
    """
    print(f'Loading checkpoint from {opts.load_checkpoint} in rank {rank}')
    # Map to CPU first: without map_location, tensors are restored onto the
    # device recorded in the checkpoint, so every rank would allocate on that
    # same device. load_state_dict then copies into the model's own parameters.
    snapshot = torch.load(opts.load_checkpoint, map_location='cpu', weights_only=True)
    result = model.net.load_state_dict(snapshot['model'], strict=False)

    missing = list(getattr(result, 'missing_keys', ()))
    unexpected = list(getattr(result, 'unexpected_keys', ()))
    if missing or unexpected:
        print(f'Warning (rank {rank}): checkpoint mismatch, '
              f'missing keys: {missing}, unexpected keys: {unexpected}')

def save_ckpt(model, ema, opt, scheduler, path):
    """Write a training snapshot to ``path`` with ``torch.save``.

    The saved dict holds four state dicts under the keys ``'model'``
    (from ``model.module.net``), ``'ema'`` (from ``ema.net``),
    ``'optimizer'`` and ``'scheduler'``.
    """
    # NOTE(review): `model.module` suggests a wrapped (e.g. DDP) model --
    # confirm against the training script that calls this.
    net_state = model.module.net.state_dict()
    ema_state = ema.net.state_dict()
    optimizer_state = opt.state_dict()
    scheduler_state = scheduler.state_dict()

    torch.save(
        dict(
            model=net_state,
            ema=ema_state,
            optimizer=optimizer_state,
            scheduler=scheduler_state,
        ),
        path,
    )

        
# Script entry point: click parses the command-line options and invokes
# `sampling` with them.
if __name__ == '__main__':
    sampling()